{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# call graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## test_callgraph_construct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@g1</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>), float32], <span style=\"color: #AA22FF; font-weight: bold\">%</span>y: Tensor[(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">3</span>), float32]) {\n",
       "  add(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #AA22FF; font-weight: bold\">%</span>y)\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod = tvm.IRModule({})\n",
    "x = relay.var(\"x\", shape=(2, 3))\n",
    "y = relay.var(\"y\", shape=(2, 3))\n",
    "mod[\"g1\"] = relay.Function([x, y], x + y)\n",
    "mod.show()\n",
    "call_graph = relay.analysis.CallGraph(mod)\n",
    "assert \"g1\" in str(call_graph)\n",
    "tvm.ir.assert_structural_equal(mod, call_graph.module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Call graph node: g1 at: 0x5647c7184360,  #refs = 0\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(call_graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 其他"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_print_element():\n",
    "    mod = tvm.IRModule({})\n",
    "    x0 = relay.var(\"x0\", shape=(2, 3))\n",
    "    y0 = relay.var(\"y0\", shape=(2, 3))\n",
    "    mod[\"g0\"] = relay.Function([x0, y0], x0 + y0)\n",
    "    x1 = relay.var(\"x1\", shape=(2, 3))\n",
    "    y1 = relay.var(\"y1\", shape=(2, 3))\n",
    "    mod[\"g1\"] = relay.Function([x1, y1], x1 - y1)\n",
    "    call_graph = relay.analysis.CallGraph(mod)\n",
    "\n",
    "    assert \"#refs = 0\" in str(call_graph.print_var(\"g0\"))\n",
    "    assert \"#refs = 0\" in str(call_graph.print_var(\"g1\"))\n",
    "\n",
    "\n",
    "def test_global_call_count():\n",
    "    mod = tvm.IRModule({})\n",
    "    x0 = relay.var(\"x0\", shape=(2, 3))\n",
    "    y0 = relay.var(\"y0\", shape=(2, 3))\n",
    "    g0 = relay.GlobalVar(\"g0\")\n",
    "    mod[g0] = relay.Function([x0, y0], x0 + y0)\n",
    "    x1 = relay.var(\"x1\", shape=(2, 3))\n",
    "    y1 = relay.var(\"y1\", shape=(2, 3))\n",
    "    g1 = relay.GlobalVar(\"g1\")\n",
    "    mod[g1] = relay.Function([x1, y1], g0(x1, y1))\n",
    "    call_graph = relay.analysis.CallGraph(mod)\n",
    "\n",
    "    p0 = relay.var(\"p0\", shape=(2, 3))\n",
    "    p1 = relay.var(\"p1\", shape=(2, 3))\n",
    "    func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))\n",
    "    mod[\"main\"] = func\n",
    "    call_graph = relay.analysis.CallGraph(mod)\n",
    "\n",
    "    assert call_graph.global_call_count(g0) == 0\n",
    "    assert call_graph.global_call_count(g1) == 1\n",
    "    assert call_graph.global_call_count(\"main\") == 2\n",
    "\n",
    "\n",
    "def test_ref_count():\n",
    "    mod = tvm.IRModule({})\n",
    "    x0 = relay.var(\"x0\", shape=(2, 3))\n",
    "    y0 = relay.var(\"y0\", shape=(2, 3))\n",
    "    g0 = relay.GlobalVar(\"g0\")\n",
    "    mod[g0] = relay.Function([x0, y0], x0 + y0)\n",
    "    x1 = relay.var(\"x1\", shape=(2, 3))\n",
    "    y1 = relay.var(\"y1\", shape=(2, 3))\n",
    "    g1 = relay.GlobalVar(\"g1\")\n",
    "    mod[g1] = relay.Function([x1, y1], x1 - y1)\n",
    "    call_graph = relay.analysis.CallGraph(mod)\n",
    "\n",
    "    p0 = relay.var(\"p0\", shape=(2, 3))\n",
    "    p1 = relay.var(\"p1\", shape=(2, 3))\n",
    "    func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))\n",
    "    mod[\"main\"] = func\n",
    "    call_graph = relay.analysis.CallGraph(mod)\n",
    "\n",
    "    assert call_graph.ref_count(g0) == 1\n",
    "    assert call_graph.ref_count(g1) == 1\n",
    "    assert call_graph.ref_count(\"main\") == 0\n",
    "\n",
    "\n",
    "def test_nested_ref():\n",
    "    mod = tvm.IRModule({})\n",
    "    x0 = relay.var(\"x0\", shape=(2, 3))\n",
    "    y0 = relay.var(\"y0\", shape=(2, 3))\n",
    "    g0 = relay.GlobalVar(\"g0\")\n",
    "    mod[g0] = relay.Function([x0, y0], x0 + y0)\n",
    "    x1 = relay.var(\"x1\", shape=(2, 3))\n",
    "    y1 = relay.var(\"y1\", shape=(2, 3))\n",
    "    g1 = relay.GlobalVar(\"g1\")\n",
    "    mod[g1] = relay.Function([x1, y1], g0(x1, y1))\n",
    "    call_graph = relay.analysis.CallGraph(mod)\n",
    "\n",
    "    p0 = relay.var(\"p0\", shape=(2, 3))\n",
    "    p1 = relay.var(\"p1\", shape=(2, 3))\n",
    "    func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))\n",
    "    mod[\"main\"] = func\n",
    "    call_graph = relay.analysis.CallGraph(mod)\n",
    "\n",
    "    assert call_graph.ref_count(g0) == 2\n",
    "    assert call_graph.ref_count(g1) == 1\n",
    "    assert call_graph.ref_count(\"main\") == 0\n",
    "\n",
    "\n",
    "def test_recursive_func():\n",
    "    mod = tvm.IRModule({})\n",
    "\n",
    "    x = relay.var(\"x\", shape=[], dtype=\"int32\")\n",
    "    fn0 = relay.Function([x], x)\n",
    "    gx = relay.GlobalVar(\"gx\")\n",
    "    mod[gx] = fn0\n",
    "\n",
    "    sum_up = relay.GlobalVar(\"sum_up\")\n",
    "    i = relay.var(\"i\", shape=[], dtype=\"int32\")\n",
    "    sb = relay.ScopeBuilder()\n",
    "    with sb.if_scope(relay.equal(i, relay.const(0, dtype=\"int32\"))):\n",
    "        sb.ret(i)\n",
    "    with sb.else_scope():\n",
    "        one_less = relay.subtract(i, relay.const(1, dtype=\"int32\"))\n",
    "        global_call = gx(i)\n",
    "        rec_call = relay.Call(sum_up, [one_less]) + global_call\n",
    "        sb.ret(relay.add(rec_call, i))\n",
    "    func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], \"int32\"))\n",
    "    func = func.with_attr(\"Compiler\", \"a\")\n",
    "    mod[sum_up] = func\n",
    "    iarg = relay.var(\"i\", shape=[], dtype=\"int32\")\n",
    "    mod[\"main\"] = relay.Function([iarg], sum_up(iarg))\n",
    "    call_graph = relay.analysis.CallGraph(mod)\n",
    "\n",
    "    assert call_graph.is_recursive(sum_up)\n",
    "    assert call_graph.ref_count(sum_up) == 2\n",
    "    assert call_graph.ref_count(gx) == 1\n",
    "    assert call_graph.ref_count(\"main\") == 0"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
